package main

import (
	"bufio"
	"fmt"
	"log"
	"net/http"
	"os"
	"regexp"
	"sort"
	"strconv"

	"gopkg.in/alecthomas/kingpin.v2"
)

// straceBasePath is the raw-content URL prefix of the strace GitHub
// repository. The requested version (used as the git ref, e.g. a tag such as
// v6.4) and one of the paths in sources are appended to it to build a
// download URL.
const straceBasePath = "https://raw.githubusercontent.com/strace/strace"

var (
	// sources maps a table name to the path of its syscallent header inside
	// the strace repository. The "common" table is appended to every
	// per-architecture table when the decoder files are generated. For v6.4:
	//
	//	x86_64:  https://github.com/strace/strace/blob/v6.4/src/linux/x86_64/syscallent.h
	//	aarch64: https://github.com/strace/strace/blob/v6.4/src/linux/64/syscallent.h
	//	common:  https://github.com/strace/strace/blob/v6.4/src/linux/generic/syscallent-common.h
	sources = map[string]string{
		"common": "src/linux/generic/syscallent-common.h",
		"amd64":  "src/linux/x86_64/syscallent.h",
		"arm64":  "src/linux/64/syscallent.h",
	}

	// syscallLineRegex matches one syscallent table entry, for example
	//
	//	[  0] = { 3, TD, SEN(read), "read" },
	//	[BASE_NR + 424] = { 4, TD|TS, SEN(pidfd_send_signal), "pidfd_send_signal" },
	//
	// Group 1 captures the syscall number and group 2 captures the name
	// inside SEN(...).
	//
	// NOTE(review): a "BASE_NR +" prefix is skipped, not added, so the number
	// is only correct where BASE_NR is 0. This is presumably true for
	// arm64/amd64; verify it before adding architectures.
	syscallLineRegex = regexp.MustCompile(`\[(?:BASE_NR \+)?(?:\s+)?(\d+)\] .*SEN\((\w+)\)`)
)

// main reads the strace version from --strace.version, downloads every table
// listed in sources for that version, and regenerates the per-architecture
// decoder files. Any download or generation failure is fatal.
func main() {
	straceVersion := kingpin.Flag("strace.version", "version of strace to use as a source").Required().String()
	kingpin.HelpFlag.Short('h')
	kingpin.Parse()

	// Every source is fetched before anything is generated, so a bad version
	// or path aborts before any output file is touched.
	tables := make(map[string]map[int]string, len(sources))
	for table, srcPath := range sources {
		parsed, err := download(*straceVersion, srcPath)
		if err != nil {
			log.Fatalf("Error downloading version %q of path %q: %v", *straceVersion, srcPath, err)
		}
		tables[table] = parsed
	}

	// Each architecture's output combines its own table with the shared one.
	common := tables["common"]
	targets := []architecture{
		{name: "arm64", path: "decoder/syscalls_arm64.go", syscalls: []map[int]string{tables["arm64"], common}},
		{name: "amd64", path: "decoder/syscalls_amd64.go", syscalls: []map[int]string{tables["amd64"], common}},
	}

	for i := range targets {
		if err := generate(targets[i]); err != nil {
			log.Fatalf("Error generating %q syscalls: %v", targets[i].name, err)
		}
	}
}

// download fetches the syscallent header at path for the given strace version
// (used as the git ref in the URL) and parses it into a map from syscall
// number to syscall name.
//
// Only lines matching syscallLineRegex are used; everything else (includes,
// preprocessor directives, placeholders) is ignored. If a number appears more
// than once, the last name read wins.
//
// It returns an error in any of these cases:
//   - the request fails;
//   - the server answers with a non-200 status, e.g. an unknown version or path;
//   - the body cannot be read;
//   - a captured number does not parse;
//   - no entries are found at all.
func download(version string, path string) (map[int]string, error) {
	uri := fmt.Sprintf("%s/%s/%s", straceBasePath, version, path)
	resp, err := http.Get(uri)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Without this check a 404 (e.g. a mistyped --strace.version) yields a
	// "404: Not Found" body with no matching lines, which would silently
	// produce an empty syscall table.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected HTTP status %q fetching %q", resp.Status, uri)
	}

	syscalls := map[int]string{}

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		matches := syscallLineRegex.FindStringSubmatch(line)
		if matches == nil {
			continue
		}

		number, err := strconv.Atoi(matches[1])
		if err != nil {
			return nil, fmt.Errorf("error parsing syscall number %q from line %q from %q: %w", matches[1], line, uri, err)
		}

		syscalls[number] = matches[2]
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading response body for %q: %w", uri, err)
	}

	// An empty result means the file format no longer matches
	// syscallLineRegex (or the wrong file was fetched). Generating an empty
	// table from it would be wrong, so fail loudly instead.
	if len(syscalls) == 0 {
		return nil, fmt.Errorf("no syscall entries found in %q", uri)
	}

	return syscalls, nil
}

// generate writes the Go source file at arch.path. The file declares
//
//	var syscalls = map[uint64]string{ number: name, ... }
//
// in package decoder under a "//go:build arch.name" constraint. The entries
// are the union of all tables in arch.syscalls, sorted by syscall number.
//
// The tables are merged before the file is created, so a naming conflict
// (see mergeSyscalls) leaves any existing output untouched. Output is
// buffered, and flush and close errors are returned so a truncated file is
// not reported as success.
//
// NOTE(review): Go tooling only recognizes generated files by a line matching
// `// Code generated ... DO NOT EDIT.`. The header written here does not match
// that convention. Consider adopting it; note that this changes the generated
// files.
func generate(arch architecture) error {
	pairs, err := mergeSyscalls(arch.syscalls)
	if err != nil {
		return fmt.Errorf("error merging syscall tables for %q: %w", arch.name, err)
	}

	file, err := os.Create(arch.path)
	if err != nil {
		return fmt.Errorf("error creating %q: %w", arch.path, err)
	}

	// bufio.Writer errors are sticky. After the first failed write, later
	// writes do nothing and Flush returns that error, so the single Flush
	// check below covers every write in between.
	w := bufio.NewWriter(file)
	fmt.Fprintf(w, "// Syscall table mapping id to name for %s\n", arch.name)
	w.WriteString("// Generated by 'make syscalls' from the root of the repo. Do not edit.\n\n")
	fmt.Fprintf(w, "//go:build %s\n\n", arch.name)
	w.WriteString("package decoder\n\n")
	w.WriteString("var syscalls = map[uint64]string{\n")
	for _, pair := range pairs {
		// "N:" is padded on the right to 4 columns, so values for keys of up
		// to three digits start in the same column.
		fmt.Fprintf(w, "\t%-4s %q,\n", strconv.Itoa(pair.number)+":", pair.name)
	}
	w.WriteString("}\n")

	if err := w.Flush(); err != nil {
		file.Close() // best effort; the write error is the one worth reporting
		return fmt.Errorf("error writing %q: %w", arch.path, err)
	}
	if err := file.Close(); err != nil {
		return fmt.Errorf("error closing %q: %w", arch.path, err)
	}

	return nil
}

// mergeSyscalls combines number-to-name tables into one list sorted by
// syscall number. A number present in several tables with the same name is
// kept once. A number mapped to different names is an error: the generated
// map literal would otherwise contain a duplicate constant key and fail to
// compile.
func mergeSyscalls(tables []map[int]string) ([]syscallPair, error) {
	merged := map[int]string{}
	for _, table := range tables {
		for number, name := range table {
			if existing, ok := merged[number]; ok && existing != name {
				return nil, fmt.Errorf("syscall %d is named both %q and %q", number, existing, name)
			}
			merged[number] = name
		}
	}

	pairs := make([]syscallPair, 0, len(merged))
	for number, name := range merged {
		pairs = append(pairs, syscallPair{number, name})
	}

	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].number < pairs[j].number
	})

	return pairs, nil
}

type (
	// architecture describes one generated syscall table. Fields are kept in
	// this order because callers build it with positional literals.
	architecture struct {
		name     string           // GOARCH name, also written as the //go:build constraint
		path     string           // output Go file passed to os.Create
		syscalls []map[int]string // number-to-name tables combined into the generated map
	}

	// syscallPair is a single syscall number together with its name.
	syscallPair struct {
		number int
		name   string
	}
)
